--- ../../plenoxels/opt/opt.py	2024-02-13 12:06:07.926166025 -0500
+++ ../google/plenoxels/opt/opt.py	2024-02-13 10:46:24.400645216 -0500
@@ -3,6 +3,23 @@
 # First, install svox2
 # Then, python opt.py <path_to>/nerf_synthetic/<scene> -t ckpt/<some_name>
 # or use launching script:   sh launch.sh <EXP_NAME> <GPU> <DATA_DIR>
+import sys, os
+# to avoid lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' 
+import cv2
+sys.path.append('../..')
+if 'DEBUG_SESSION' in os.environ:
+    from my_python_utils.common_utils import *
+
+def str2bool(v):
+  """argparse `type=` helper: map a yes/no style value (case-insensitive) to a bool, else raise ArgumentTypeError."""
+  text = str(v).lower()  # str() also accepts real bools/ints; the previous assert was stripped under -O
+  if text in ('yes', 'true', 't', 'y', '1'):
+    return True
+  if text in ('no', 'false', 'f', 'n', '0'):
+    return False
+  raise argparse.ArgumentTypeError('Boolean (yes, true, t, y or 1, lower or upper case) string expected.')
+from copy import deepcopy
+
 import torch
 import torch.cuda
 import torch.optim
@@ -19,8 +36,22 @@
 import argparse
 import cv2
 from util.dataset import datasets
-from util.util import Timing, get_expon_lr_func, generate_dirs_equirect, viridis_cmap
 from util import config_util
+from util.util import Timing, get_expon_lr_func, generate_dirs_equirect, viridis_cmap
+from multiscale import *
+
+try:
+  import cPickle as pickle  # C-accelerated pickle module on Python 2
+except ImportError:  # cPickle does not exist on Python 3; only a failed import should trigger the fallback
+  import _pickle as pickle
+
+def dump_to_pickle(filename, obj):
+  try:  # preferred path: serialize `obj` to `filename` with pickle.HIGHEST_PROTOCOL
+    with open(filename, 'wb') as handle:
+      pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
+  except Exception:  # best-effort retry with the default protocol; not BaseException, so Ctrl-C still aborts
+    with open(filename, 'wb') as handle:
+      pickle.dump(obj, handle)
 
 from warnings import warn
 from datetime import datetime
@@ -29,6 +60,7 @@
 from tqdm import tqdm
 from typing import NamedTuple, Optional, Union
 
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 parser = argparse.ArgumentParser()
@@ -40,13 +72,18 @@
                      help='checkpoint and logging directory')
 
 group.add_argument('--reso',
-                        type=str,
-                        default=
-                        "[[256, 256, 256], [512, 512, 512]]",
-                       help='List of grid resolution (will be evaled as json);'
+                    type=str,
+                    default="[[256, 256, 256], [512, 512, 512]]",
+                    help='List of grid resolution (will be evaled as json);'
                             'resamples to the next one every upsamp_every iters, then ' +
                             'stays at the last one; ' +
                             'should be a list where each item is a list of 3 ints or an int')
+
+group.add_argument('--base_reso',
+                    type=int,
+                    default=16,
+                    help='Base resolution to start with (for upsampling), when using multiscale')
+
 group.add_argument('--upsamp_every', type=int, default=
                      3 * 12800,
                     help='upsample the grid every x iters')
@@ -65,6 +102,11 @@
 group.add_argument('--basis_reso', type=int, default=32,
                    help='basis grid resolution (only for learned texture)')
 group.add_argument('--sh_dim', type=int, default=9, help='SH/learned basis dimensions (at most 10)')
+group.add_argument('--sh_dim_per_scale',
+                    type=str,
+                    default="None",
+                    help='List of sh per scale')
+
 
 group.add_argument('--mlp_posenc_size', type=int, default=4, help='Positional encoding size if using MLP basis; 0 to disable')
 group.add_argument('--mlp_width', type=int, default=32, help='MLP width if using MLP basis')
@@ -78,7 +120,8 @@
 group = parser.add_argument_group("optimization")
 group.add_argument('--n_iters', type=int, default=10 * 12800, help='total number of iters to optimize for')
 group.add_argument('--batch_size', type=int, default=
-                     5000,
+                     #5000,
+                     128 ** 2,
                      #100000,
                      #  2000,
                    help='batch size')
@@ -86,7 +129,7 @@
 
 # TODO: make the lr higher near the end
 group.add_argument('--sigma_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Density optimizer")
-group.add_argument('--lr_sigma', type=float, default=3e1, help='SGD/rmsprop lr for sigma')
+group.add_argument('--lr_sigma', type=float, default=3e1, help='SGD/rmsprop lr for sigma') # originally 3e1
 group.add_argument('--lr_sigma_final', type=float, default=5e-2)
 group.add_argument('--lr_sigma_decay_steps', type=int, default=250000)
 group.add_argument('--lr_sigma_delay_steps', type=int, default=15000,
@@ -144,7 +187,7 @@
 group.add_argument('--rms_beta', type=float, default=0.95, help="RMSProp exponential averaging factor")
 
 group.add_argument('--print_every', type=int, default=20, help='print every')
-group.add_argument('--save_every', type=int, default=5,
+group.add_argument('--save_every', type=int, default=1,
                    help='save every x epochs')
 group.add_argument('--eval_every', type=int, default=1,
                    help='evaluate every x epochs')
@@ -162,6 +205,12 @@
 group.add_argument('--log_depth_map_use_thresh', type=float, default=None,
         help="If specified, uses the Dex-neRF version of depth with given thresh; else returns expected term")
 
+group.add_argument('--debug', type=str2bool, default="False",
+        help="If True, will run in debug mode and show things on visdom server")
+
+group.add_argument('--preserve_previous_grid_values', type=str2bool, default="False",
+        help="When upsampling, preserve the grid values from previous scales, instead of recomputing them from the current grid")
+
 
 group = parser.add_argument_group("misc experiments")
 group.add_argument('--thresh_type',
@@ -171,7 +220,7 @@
 group.add_argument('--weight_thresh', type=float,
                     default=0.0005 * 512,
                     #  default=0.025 * 512,
-                   help='Upsample weight threshold; will be divided by resulting z-resolution')
+                    help='Upsample weight threshold; will be divided by resulting z-resolution')
 group.add_argument('--density_thresh', type=float,
                     default=5.0,
                    help='Upsample sigma threshold')
@@ -179,9 +228,9 @@
                     default=1.0+1e-9,
                    help='Background sigma threshold for sparsification')
 group.add_argument('--max_grid_elements', type=int,
-                    default=44_000_000,
+                    default=22_000_000,
                    help='Max items to store after upsampling '
-                        '(the number here is given for 22GB memory)')
+                        '(the number here is given for 22GB memory)') # halved from the original 44M for the multiscale; NOTE(review): the 22GB figure in the help text was written for 44M
 
 group.add_argument('--tune_mode', action='store_true', default=False,
                    help='hypertuning mode (do not save, for speed)')
@@ -200,7 +249,7 @@
 group.add_argument('--lambda_tv_sh', type=float, default=1e-3)
 group.add_argument('--tv_sh_sparsity', type=float, default=0.01)
 
-group.add_argument('--lambda_tv_lumisphere', type=float, default=0.0)#1e-2)#1e-3)
+group.add_argument('--lambda_tv_lumisphere', type=float, default=0.0) #1e-2)#1e-3)
 group.add_argument('--tv_lumisphere_sparsity', type=float, default=0.01)
 group.add_argument('--tv_lumisphere_dir_factor', type=float, default=0.0)
 
@@ -240,10 +289,27 @@
 group.add_argument('--lr_decay', action='store_true', default=True)
 
 group.add_argument('--n_train', type=int, default=None, help='Number of training images. Defaults to use all avaiable.')
+group.add_argument('--n_train_percentage', type=float, default=None, help='Fraction in (0, 1] of the training images to use (mutually exclusive with --n_train). Defaults to use all available.')
 
 group.add_argument('--nosphereinit', action='store_true', default=False,
                      help='do not start with sphere bounds (please do not use for 360)')
 
+group = parser.add_argument_group("my_args")
+group.add_argument('--use-pytorch', default='True', type=str2bool, help='meant to be forwarded as use_pytorch to volume_render_image; currently unused (the forwarding is commented out)')
+group.add_argument('--no-dilate', default='False', type=str2bool, help='if True, resample with dilate=0 instead of dilate=2 when upsampling the grid')
+
+group.add_argument('--multiscale', default='False', type=str2bool, help='whether to use multiscale')
+group.add_argument('--multiscale_sigma', default='False', type=str2bool, help='whether to use multiscale for sigma')
+group.add_argument('--propagate_forward', default='False', type=str2bool, help='whether to propagate values to the parent scales and back to the leaves before rendering each training batch')
+
+group.add_argument('--multiscale_gradients', default='True', type=str2bool, help='whether to propagate gradients to the parent scales and back to the leaves before each optimizer step')
+group.add_argument('--multiply_grads_per_scale', default='True', type=str2bool, help='Scale each grid gradient by base_reso / grid resolution, so that coarser resolutions influence the gradient more')
+group.add_argument('--froze_previous_scales', default='False', type=str2bool, help='Zero the gradients of the previously fitted scales after the first epoch (needs --multiscale_gradients and --multiply_grads_per_scale)')
+
+
+group.add_argument('--masking_percentage', default='-1', type=float, help='masking percentage of training images. if < 0 is not used.')
+
+
 args = parser.parse_args()
 config_util.maybe_merge_config_file(args)
 
@@ -251,10 +317,29 @@
 assert args.lr_sh_final <= args.lr_sh, "lr_sh must be >= lr_sh_final"
 assert args.lr_basis_final <= args.lr_basis, "lr_basis must be >= lr_basis_final"
 
+
+IMSHOW_BIGGEST_DIM=256
+MULTISCALE = args.multiscale or args.multiscale_sigma
+
 os.makedirs(args.train_dir, exist_ok=True)
-summary_writer = SummaryWriter(args.train_dir)
+summary_writer = SummaryWriter(args.train_dir.replace('/gcs/', 'gs://'))
 
 reso_list = json.loads(args.reso)
+if not args.sh_dim_per_scale == 'None' and MULTISCALE:
+    sh_dim_per_scale = json.loads(args.sh_dim_per_scale)
+    reso_list_all = []
+    reso = args.base_reso
+    while reso < reso_list[0][0]:
+        reso_list_all += [[reso, reso, reso]]
+        reso *= 2
+    reso_list_all += reso_list
+    assert len(sh_dim_per_scale) == len(reso_list_all), \
+        "Must specify sh_dim_per_scale for each scale {}, taking into account that all resolutions are {}".format(sh_dim_per_scale, 
+                                                                                                                                                                  reso_list_all)
+else:
+    sh_dim_per_scale = None
+
+args.sh_dim_per_scale = sh_dim_per_scale
 reso_id = 0
 
 with open(path.join(args.train_dir, 'args.json'), 'w') as f:
@@ -262,17 +347,33 @@
     # Changed name to prevent errors
     shutil.copyfile(__file__, path.join(args.train_dir, 'opt_frozen.py'))
 
-torch.manual_seed(20200823)
-np.random.seed(20200823)
+torch.manual_seed(20200823) # fixed seed (same value as the numpy seed below)
+np.random.seed(20200823) # fixed seed (same value as the torch seed above)
 
 factor = 1
+
+assert (not args.n_train is None) + (not args.n_train_percentage is None) <= 1, "Cannot specify both n_train and percentage of images to use" 
+assert args.n_train_percentage is None or 0 < args.n_train_percentage <= 1, "Percentage of images to use must be between 0 and 1, if specified"
+
+
 dset = datasets[args.dataset_type](
                args.data_dir,
                split="train",
                device=device,
                factor=factor,
                n_images=args.n_train,
+               masking_percentage=args.masking_percentage,
                **config_util.build_data_options(args))
+if args.n_train_percentage is not None:
+    n_images_by_percentage = max(1, int(args.n_train_percentage * dset.n_images))  # at least one image; dset was loaded above only to learn n_images
+    dset = datasets[args.dataset_type](
+               args.data_dir,
+               split="train",
+               device=device,
+               factor=factor,
+               n_images=n_images_by_percentage,
+               masking_percentage=args.masking_percentage,  # was dropped on reload, silently disabling masking
+               **config_util.build_data_options(args))
 
 if args.background_nlayers > 0 and not dset.should_use_background:
     warn('Using a background model for dataset type ' + str(type(dset)) + ' which typically does not use background')
@@ -280,9 +381,24 @@
 dset_test = datasets[args.dataset_type](
         args.data_dir, split="test", **config_util.build_data_options(args))
 
+
+def display_grid(grid, title=''):
+    """Debug helper: pass the unit-cube coordinates of the grid nodes whose link is >= 0 to show_pointcloud."""
+    # One evenly spaced [0, 1] axis per spatial dimension of the grid.
+    axes = [np.linspace(0, 1, grid.shape[dim]) for dim in range(3)]
+    mesh_x, mesh_y, mesh_z = np.meshgrid(*axes, indexing='ij')
+
+    # Mask of the nodes with a non-negative link (tonumpy presumably yields a NumPy array -- confirm).
+    valid_nodes = tonumpy(grid.links >= 0)
+    # One column per axis, one row per selected node.
+    coords = np.concatenate([axis_vals[valid_nodes][:, None] for axis_vals in (mesh_x, mesh_y, mesh_z)], -1)
+    show_pointcloud(coords, title=title)
+
 global_start_time = datetime.now()
+previous_grids = []
 
-grid = svox2.SparseGrid(reso=reso_list[reso_id],
+if not MULTISCALE:
+    grid = svox2.SparseGrid(reso=reso_list[reso_id],
                         center=dset.scene_center,
                         radius=dset.scene_radius,
                         use_sphere_bound=dset.use_sphere_bound and not args.nosphereinit,
@@ -296,6 +412,54 @@
                         background_nlayers=args.background_nlayers,
                         background_reso=args.background_reso)
 
+else:        
+    # create initial scales
+    reso_list_multiscale = []
+    reso = args.base_reso
+    while reso < reso_list[0][0]:
+        reso_list_multiscale += [[reso, reso, reso]]
+        reso *= 2
+    reso_list_multiscale += [reso_list[0]]
+    grid = svox2.SparseGrid(reso=reso_list_multiscale[0],
+                    center=dset.scene_center,
+                    radius=dset.scene_radius,
+                    use_sphere_bound=dset.use_sphere_bound and not args.nosphereinit,
+                    basis_dim=args.sh_dim,
+                    use_z_order=True,
+                    device=device,
+                    basis_reso=args.basis_reso,
+                    basis_type=svox2.__dict__['BASIS_TYPE_' + args.basis_type.upper()],
+                    mlp_posenc_size=args.mlp_posenc_size,
+                    mlp_width=args.mlp_width,
+                    background_nlayers=args.background_nlayers,
+                    background_reso=args.background_reso)
+    if args.debug:
+        display_grid(grid, title='initial_grid_{}'.format(grid.shape))
+    for cur_res_id in range(1, len(reso_list_multiscale)):
+        z_reso = reso_list_multiscale[cur_res_id][2]
+
+        previous_grid = deepcopy(grid)
+        # we initialize
+        grid.density_data.data[:] = 10
+        grid.resample(reso=reso_list_multiscale[cur_res_id],
+                      sigma_thresh=0,
+                      weight_thresh=-1,
+                      dilate=0, #use_sparsify,
+                      cameras=None,
+                      max_elements=args.max_grid_elements)
+        grid.density_data.data[:] = 0
+        if args.debug:
+            display_grid(grid, title='initial_grid_{}'.format(grid.shape))
+
+        with torch.no_grad():
+            p_to_current_map = get_links_to_current_grid(grid, previous_grid)
+            if args.multiscale:
+                previous_grid.sh_data -= previous_grid.sh_data
+            if args.multiscale_sigma:
+                previous_grid.density_data -= previous_grid.density_data
+            previous_grids = [(previous_grid,p_to_current_map), *previous_grids]
+grid.multiscale = MULTISCALE
+    
 # DC -> gray; mind the SH scaling!
 grid.sh_data.data[:] = 0.0
 grid.density_data.data[:] = 0.0 if args.lr_fg_begin_step > 0 else args.init_sigma
@@ -345,6 +509,7 @@
                      ndc_coeffs=dset.ndc_coeffs) for i, c2w in enumerate(dset.c2w)
     ]
 ckpt_path = path.join(args.train_dir, 'ckpt.npz')
+results_path = path.join(args.train_dir, 'results.pkl')
 
 lr_sigma_func = get_expon_lr_func(args.lr_sigma, args.lr_sigma_final, args.lr_sigma_delay_steps,
                                   args.lr_sigma_delay_mult, args.lr_sigma_decay_steps)
@@ -361,25 +526,32 @@
 lr_basis_factor = 1.0
 
 last_upsamp_step = args.init_iters
+    
 
+    
 if args.enable_random:
     warn("Randomness is enabled for training (normal for LLFF & scenes with background)")
 
+all_train_stats = []
+all_eval_stats = []
+
 epoch_id = -1
+test_cam = None
 while True:
     dset.shuffle_rays()
     epoch_id += 1
     epoch_size = dset.rays.origins.size(0)
     batches_per_epoch = (epoch_size-1)//args.batch_size+1
     # Test
-    def eval_step():
+    def eval_step(plot_prefix=''):
         # Put in a function to avoid memory leak
+        global test_cam
         print('Eval step')
         with torch.no_grad():
             stats_test = {'psnr' : 0.0, 'mse' : 0.0}
 
             # Standard set
-            N_IMGS_TO_EVAL = min(20 if epoch_id > 0 else 5, dset_test.n_images)
+            N_IMGS_TO_EVAL = dset_test.n_images if not args.debug else 1
             N_IMGS_TO_SAVE = N_IMGS_TO_EVAL # if not args.tune_mode else 1
             img_eval_interval = dset_test.n_images // N_IMGS_TO_EVAL
             img_save_interval = (N_IMGS_TO_EVAL // N_IMGS_TO_SAVE)
@@ -403,8 +575,27 @@
                                    width=dset_test.get_image_size(img_id)[1],
                                    height=dset_test.get_image_size(img_id)[0],
                                    ndc_coeffs=dset_test.ndc_coeffs)
-                rgb_pred_test = grid.volume_render_image(cam, use_kernel=True)
+                test_cam = cam
+                if i == 0 and MULTISCALE:
+                    propagate_multiscale_to_parents(grid, previous_grids, args, sh_dim_per_scale)
+                    for g, _ in [*previous_grids, (grid, None)]:
+                        rgb_pred_test_grid = g.volume_render_image(cam) #, use_pytorch=args.use_pytorch)
+                        rgb_pred_test_grid = np.clip(rgb_pred_test_grid.cpu().numpy(), a_min=0, a_max=1)
+                        summary_writer.add_image(plot_prefix + 'rgb_pred_test_grid_res_{}'.format(g.shape[0]),
+                                                 rgb_pred_test_grid, global_step=gstep_id_base, dataformats='HWC')
+                   
+                        if args.debug:
+                            imshow(rgb_pred_test_grid, title=plot_prefix + 'rgb_pred_test_grid_res_{}'.format(g.shape[0]),biggest_dim=IMSHOW_BIGGEST_DIM)
+                    propagate_multiscale_to_leaves(grid, previous_grids, args, sh_dim_per_scale)
+                rgb_pred_test = grid.volume_render_image(cam) #, use_pytorch=args.use_pytorch)
                 rgb_gt_test = dset_test.gt[img_id].to(device=device)
+                
+                if args.debug:
+                    to_plot = np.clip(rgb_pred_test.cpu().numpy(), a_min=0, a_max=1)
+
+                    imshow(to_plot, title=plot_prefix + 'rgb_pred_test', biggest_dim=IMSHOW_BIGGEST_DIM)
+                    imshow(rgb_gt_test.cpu().numpy(), title=plot_prefix + 'rgb_gt_test', biggest_dim=IMSHOW_BIGGEST_DIM)
+
                 all_mses = ((rgb_gt_test - rgb_pred_test) ** 2).cpu()
                 if i % img_save_interval == 0:
                     img_pred = rgb_pred_test.cpu()
@@ -415,16 +606,20 @@
                         mse_img = all_mses / all_mses.max()
                         summary_writer.add_image(f'test/mse_map_{img_id:04d}',
                                 mse_img, global_step=gstep_id_base, dataformats='HWC')
+
                     if args.log_depth_map:
                         depth_img = grid.volume_render_depth_image(cam,
                                     args.log_depth_map_use_thresh if
                                     args.log_depth_map_use_thresh else None
                                 )
-                        depth_img = viridis_cmap(depth_img.cpu())
+                        depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min()).clamp(min=1e-8)  # min-max normalize; clamp avoids 0/0 -> NaN on a constant depth map
+                        # depth_img = viridis_cmap(depth_img.cpu())
                         summary_writer.add_image(f'test/depth_map_{img_id:04d}',
-                                depth_img,
+                                torch.cat([depth_img[...,None]]*3, -1),
                                 global_step=gstep_id_base, dataformats='HWC')
-
+                        if args.debug:
+                            imshow(depth_img, title='depth_img', biggest_dim=IMSHOW_BIGGEST_DIM)
+                    
                 rgb_pred_test = rgb_gt_test = None
                 mse_num : float = all_mses.mean().item()
                 psnr = -10.0 * math.log10(mse_num)
@@ -435,29 +630,6 @@
                 stats_test['psnr'] += psnr
                 n_images_gen += 1
 
-            if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE or \
-               grid.basis_type == svox2.BASIS_TYPE_MLP:
-                 # Add spherical map visualization
-                EQ_RESO = 256
-                eq_dirs = generate_dirs_equirect(EQ_RESO * 2, EQ_RESO)
-                eq_dirs = torch.from_numpy(eq_dirs).to(device=device).view(-1, 3)
-
-                if grid.basis_type == svox2.BASIS_TYPE_MLP:
-                    sphfuncs = grid._eval_basis_mlp(eq_dirs)
-                else:
-                    sphfuncs = grid._eval_learned_bases(eq_dirs)
-                sphfuncs = sphfuncs.view(EQ_RESO, EQ_RESO*2, -1).permute([2, 0, 1]).cpu().numpy()
-
-                stats = [(sphfunc.min(), sphfunc.mean(), sphfunc.max())
-                        for sphfunc in sphfuncs]
-                sphfuncs_cmapped = [viridis_cmap(sphfunc) for sphfunc in sphfuncs]
-                for im, (minv, meanv, maxv) in zip(sphfuncs_cmapped, stats):
-                    cv2.putText(im, f"{minv=:.4f} {meanv=:.4f} {maxv=:.4f}", (10, 20),
-                                0, 0.5, [255, 0, 0])
-                sphfuncs_cmapped = np.concatenate(sphfuncs_cmapped, axis=0)
-                summary_writer.add_image(f'test/spheric',
-                        sphfuncs_cmapped, global_step=gstep_id_base, dataformats='HWC')
-                # END add spherical map visualization
 
             stats_test['mse'] /= n_images_gen
             stats_test['psnr'] /= n_images_gen
@@ -466,15 +638,27 @@
                         stats_test[stat_name], global_step=gstep_id_base)
             summary_writer.add_scalar('epoch_id', float(epoch_id), global_step=gstep_id_base)
             print('eval stats:', stats_test)
-    if epoch_id % max(factor, args.eval_every) == 0: #and (epoch_id > 0 or not args.tune_mode):
+            
+            return stats_test
+            
+    if True and (epoch_id % max(factor, args.eval_every) == 0): #and (epoch_id > 0 or not args.tune_mode):
         # NOTE: we do an eval sanity check, if not in tune_mode
-        eval_step()
+        eval_stats = eval_step()
+        all_eval_stats.append(('epoch_{}'.format(epoch_id), eval_stats))
         gc.collect()
 
+    
     def train_step():
         print('Train step')
+        global test_cam
+        all_depth_imgs = []
+        all_color_imgs = []
+        previous_values = None
+        debug_info = {'all_depth_imgs': all_depth_imgs,
+                      'all_color_imgs': all_color_imgs}
         pbar = tqdm(enumerate(range(0, epoch_size, args.batch_size)), total=batches_per_epoch)
         stats = {"mse" : 0.0, "psnr" : 0.0, "invsqr_mse" : 0.0}
+        epoch_stats = {"mse" : 0.0, "psnr" : 0.0}
         for iter_id, batch_begin in pbar:
             gstep_id = iter_id + gstep_id_base
             if args.lr_fg_begin_step > 0 and gstep_id == args.lr_fg_begin_step:
@@ -494,13 +678,23 @@
             batch_dirs = dset.rays.dirs[batch_begin: batch_end]
             rgb_gt = dset.rays.gt[batch_begin: batch_end]
             rays = svox2.Rays(batch_origins, batch_dirs)
-
+            
+            if args.propagate_forward:
+                # aggregate information and propagate
+                #if args.debug:
+                #    grid_values_before = tonumpy(grid.sh_data)
+                propagate_multiscale_to_parents(grid, previous_grids, args, sh_dim_per_scale)
+                propagate_multiscale_to_leaves(grid, previous_grids, args, sh_dim_per_scale)
+                #if args.debug:
+                #    grid_values_after = tonumpy(grid.sh_data)
+                
             #  with Timing("volrend_fused"):
             rgb_pred = grid.volume_render_fused(rays, rgb_gt,
                     beta_loss=args.lambda_beta,
                     sparsity_loss=args.lambda_sparsity,
                     randomize=args.enable_random)
 
+
             #  with Timing("loss_comp"):
             mse = F.mse_loss(rgb_gt, rgb_pred)
 
@@ -510,6 +704,9 @@
             stats['mse'] += mse_num
             stats['psnr'] += psnr
             stats['invsqr_mse'] += 1.0 / mse_num ** 2
+            
+            epoch_stats['mse'] += mse_num
+            epoch_stats['psnr'] += psnr
 
             if (iter_id + 1) % args.print_every == 0:
                 # Print averaged stats
@@ -541,16 +738,7 @@
                 if args.weight_decay_sigma < 1.0:
                     grid.density_data.data *= args.weight_decay_sh
 
-            #  # For outputting the % sparsity of the gradient
-            #  indexer = grid.sparse_sh_grad_indexer
-            #  if indexer is not None:
-            #      if indexer.dtype == torch.bool:
-            #          nz = torch.count_nonzero(indexer)
-            #      else:
-            #          nz = indexer.size()
-            #      with open(os.path.join(args.train_dir, 'grad_sparsity.txt'), 'a') as sparsity_file:
-            #          sparsity_file.write(f"{gstep_id} {nz}\n")
-
+            # Main two regularizations
             # Apply TV/Sparsity regularizers
             if args.lambda_tv > 0.0:
                 #  with Timing("tv_inpl"):
@@ -567,6 +755,8 @@
                         sparse_frac=args.tv_sh_sparsity,
                         ndc_coeffs=dset.ndc_coeffs,
                         contiguous=args.tv_contiguous)
+            # Main two regularizations
+            
             if args.lambda_tv_lumisphere > 0.0:
                 grid.inplace_tv_lumisphere_grad(grid.sh_data.grad,
                         scaling=args.lambda_tv_lumisphere,
@@ -590,6 +780,26 @@
             #        ' sh', torch.count_nonzero(grid.sparse_sh_grad_indexer).item())
 
             # Manual SGD/rmsprop step
+            
+            if args.multiscale_gradients:
+                # masks out by sh_dim_per_scale if necessary
+                propagate_multiscale_to_parents(grid, previous_grids, args, sh_dim_per_scale, 
+                                                grad=True)
+                if args.multiply_grads_per_scale:
+                    with torch.no_grad():
+                        froze_factor = 0 if args.froze_previous_scales and epoch_id > 0 else 1
+                        base_scale = args.base_reso
+                        for previous_grid in previous_grids:
+                            grid_scale = previous_grid[0].shape[0]
+                            previous_grid[0].density_data.grad *= base_scale / grid_scale * froze_factor
+                            previous_grid[0].sh_data.grad *= base_scale / grid_scale * froze_factor
+                        grid.density_data.grad *= 1 if froze_factor == 0 else base_scale / grid.shape[0]
+                        grid.sh_data.grad *= 1 if froze_factor == 0 else base_scale / grid.shape[0]
+                propagate_multiscale_to_leaves(grid, previous_grids, args, sh_dim_per_scale, 
+                                               grad=True)
+                #if args.debug:
+                #    grads_after = tonumpy(grid.density_data.grad)
+
             if gstep_id >= args.lr_fg_begin_step:
                 grid.optim_density_step(lr_sigma, beta=args.rms_beta, optim=args.sigma_optim)
                 grid.optim_sh_step(lr_sh, beta=args.rms_beta, optim=args.sh_optim)
@@ -601,17 +811,24 @@
                 elif grid.basis_type == svox2.BASIS_TYPE_MLP:
                     optim_basis_mlp.step()
                     optim_basis_mlp.zero_grad()
-
-    train_step()
+        
+        debug_info['all_depth_imgs'] = all_depth_imgs
+        
+        for stat_name in epoch_stats:
+            epoch_stats[stat_name] = epoch_stats[stat_name] / batches_per_epoch
+        return epoch_stats, debug_info
+    
+    train_stats, debug_info = train_step()
+    all_train_stats.append(('epoch_{}'.format(epoch_id), train_stats))
+    
     gc.collect()
     gstep_id_base += batches_per_epoch
 
     #  ckpt_path = path.join(args.train_dir, f'ckpt_{epoch_id:05d}.npz')
     # Overwrite prev checkpoints since they are very huge
-    if args.save_every > 0 and (epoch_id + 1) % max(
-            factor, args.save_every) == 0 and not args.tune_mode:
-        print('Saving', ckpt_path)
-        grid.save(ckpt_path)
+    #if args.save_every > 0 and (epoch_id + 1) % max(factor, args.save_every) == 0 and not args.tune_mode:
+    #    print('Saving', ckpt_path)
+    #    grid.save(ckpt_path)
 
     if (gstep_id_base - last_upsamp_step) >= args.upsamp_every:
         last_upsamp_step = gstep_id_base
@@ -628,19 +845,51 @@
             reso_id += 1
             use_sparsify = True
             z_reso = reso_list[reso_id] if isinstance(reso_list[reso_id], int) else reso_list[reso_id][2]
-            grid.resample(reso=reso_list[reso_id],
-                    sigma_thresh=args.density_thresh,
-                    weight_thresh=args.weight_thresh / z_reso if use_sparsify else 0.0,
-                    dilate=2, #use_sparsify,
-                    cameras=resample_cameras if args.thresh_type == 'weight' else None,
-                    max_elements=args.max_grid_elements)
 
+            previous_grid = deepcopy(grid)
+            print("Resampling with max_elements={}".format(args.max_grid_elements))
+            grid.resample(reso=reso_list[reso_id],
+                          sigma_thresh=args.density_thresh,
+                          weight_thresh=args.weight_thresh / z_reso if use_sparsify else 0.0,
+                          dilate=0 if args.no_dilate else 2, #use_sparsify,
+                          cameras=resample_cameras if args.thresh_type == 'weight' else None,
+                          max_elements=args.max_grid_elements)
+            print("Finished resampling")
             if grid.use_background and reso_id <= 1:
                 grid.sparsify_background(args.background_density_thresh)
+            if args.debug:
+                display_grid(grid, title='active_nodes_scale_{}_{}_{}'.format(*grid.shape[:3]))        
 
             if args.upsample_density_add:
                 grid.density_data.data[:] += args.upsample_density_add
 
+            if MULTISCALE:
+                with torch.no_grad():
+                    p_to_current_map = get_links_to_current_grid(grid, previous_grid)
+
+                    # TODO: to save memory, maybe only store sh_data/ (not the full previous_grid), 
+                    # and remove data/links that are not longer relevant for current grid
+                    # this can be easily done by resorting using the links.
+                    # Although this is useful for visualization.
+                    if args.preserve_previous_grid_values:
+                        propagate_multiscale_to_parents(previous_grid, previous_grids, args, sh_dim_per_scale)
+                        previous_grids = [(previous_grid,p_to_current_map), *previous_grids]
+                        grid.density_data.data[:] = 0
+                        grid.sh_data.data[:] = 0
+                        grid.multiscale = False
+                    else:                        
+                        if args.multiscale:
+                            previous_grid.sh_data -= previous_grid.sh_data
+                        if args.multiscale_sigma:
+                            previous_grid.density_data -= previous_grid.density_data
+                        previous_grids = [(previous_grid,p_to_current_map), *previous_grids]
+
+                        propagate_multiscale_to_parents(grid, previous_grids, args, sh_dim_per_scale)
+                        # replace values of previous grids 
+                    propagate_multiscale_to_leaves(grid, previous_grids, args, sh_dim_per_scale)
+
+                a = 1
+                
         if factor > 1 and reso_id < len(reso_list) - 1:
             print('* Using higher resolution images due to large grid; new factor', factor)
             factor //= 2
@@ -649,11 +898,24 @@
 
     if gstep_id_base >= args.n_iters:
         print('* Final eval and save')
-        eval_step()
+        final_eval_stats = eval_step()
+        all_eval_stats.append(('final', final_eval_stats))
         global_stop_time = datetime.now()
         secs = (global_stop_time - global_start_time).total_seconds()
         timings_file = open(os.path.join(args.train_dir, 'time_mins.txt'), 'a')
         timings_file.write(f"{secs / 60}\n")
         if not args.tune_nosave:
             grid.save(ckpt_path)
+            for (previous_grid,p_to_current_map) in previous_grids:
+                previous_grid.save(ckpt_path.replace('.npz', f'_prev_{str(previous_grid.shape[0]).zfill(4)}.npz'))
+                dump_to_pickle(ckpt_path.replace('.npz', f'_prev_{str(previous_grid.shape[0]).zfill(4)}_map.pkl'), 
+                            p_to_current_map)
+            
+        things_to_save = dict(final_eval_stats=final_eval_stats,
+                              all_eval_stats=all_eval_stats,
+                              all_train_stats=all_train_stats,
+                              n_train_images=dset.n_images,
+                              n_eval_images=dset_test.n_images,
+                              args=args)
+        dump_to_pickle(results_path, things_to_save)
         break
